{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "'''\n",
    "from Attention is All You Need (https://arxiv.org/pdf/1706.03762)\n",
    "which introduced the notion of multiple \"Attention Heads\" per layer instead of vanilla self-attention \n",
    "which had been around since the Bahdanue NMT paper in 2014 \n",
    "in addition to introducing the Transformer architecture, encoder/decoder, etc\n",
    "The interpretation is that we split up the full residual stream dimension D \n",
    "by hand to force different subspaces to try attend to different things (\"concepts\" in the input)\n",
    "this turns out to be a useful inductive bias, even though in principle it doesn't add any expressive power\n",
    "since the full self-attention could have \"learned\" the subspaces. \n",
    "In practice, different heads learn different aspects of language and different \n",
    "functions, like induction heads (see Anthropic \"A Mathematical Model of Transformer Circuits\")\n",
    "successor heads (see the ICLR 2024 paper by Gould, Ong, et al.), and more. \n",
    "'''\n",
    "\n",
    "import torch \n",
    "import torch.nn as nn \n",
    "import torch.nn.functional as F \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MHSA(nn.Module): \n",
    "    def __init__(self, D, head_dim=64): # we fix head_dim and scale number of heads as D grows\n",
    "        super().__init__()\n",
    "        self.D = D # \"residual stream dimension,\" see 2022 Anthropic papers \n",
    "        self.wq = nn.Linear(D,D) \n",
    "        self.wk = nn.Linear(D,D)\n",
    "        self.wv = nn.Linear(D,D)\n",
    "        self.wo = nn.Linear(D,D)\n",
    "        \n",
    "        assert D % head_dim == 0 \n",
    "        self.n_heads = D//head_dim \n",
    "        self.head_dim = head_dim \n",
    "\n",
    "    def forward(self, x): # BSD -> BSD \n",
    "        B, S, D = x.shape\n",
    "        q, k, v = self.wq(x), self.wk(x), self.wv(x) # BSD -> BSD\n",
    "\n",
    "        # view as B, n_heads, S, head_dim\n",
    "        q = q.view(B, self.n_heads, S, self.head_dim)\n",
    "        k = k.view(B, self.n_heads, S, self.head_dim)\n",
    "        v = v.view(B, self.n_heads, S, self.head_dim)\n",
    "        \n",
    "        # compute attn scores using einsum to do q@k.t within each head now \n",
    "        scores = torch.einsum('bnqd,bnkd->bnqk', q, k) # [batch, nheads, seq, seq]\n",
    "        scores = scores / (self.head_dim ** 0.5) # scale by sqrt(d_k)\n",
    "        \n",
    "        # the dim=-1 below was always confusing to me when learning about softmax\n",
    "        # since we want to normalize each row \n",
    "        # the intuition is that to normalize each row, we want the COLUMNS to sum to 1\n",
    "        # hence should tell torch to softmax over the cols, which here is the last dim of A \n",
    "        # since A is [B, S, S] \n",
    "        \n",
    "        A = F.softmax(scores, dim=-1)\n",
    "        \n",
    "        # apply attention: multiply attention weights with values\n",
    "        # A is [batch, n_heads, seq, seq]\n",
    "        # v is [batch, n_heads, seq, head_dim]\n",
    "        # we want out to be [batch, n_heads, seq, head_dim]\n",
    "        out = torch.einsum('bnqk,bnkd->bnqd', A, v)\n",
    "\n",
    "        # A is BSS and v is BSD\n",
    "        # A@v is BSD \n",
    "        # A is queries (rows) by keys (cols), so want all rows to sum to 1\n",
    "        # because a query can only pay unit attention over all keys behind it\n",
    "        \n",
    "        # reshape from [batch, n_heads, seq, head_dim] back to [batch, seq, dim]\n",
    "        out = out.view(B, S, D)\n",
    "        \n",
    "        return self.wo(out)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "All tests passed!\n"
     ]
    }
   ],
   "source": [
    "# test it\n",
    "if __name__ == \"__main__\":\n",
    "    torch.manual_seed(42)\n",
    "    \n",
    "    batch_size = 2\n",
    "    seq_len = 4\n",
    "    dim = 128\n",
    "    head_dim = 32\n",
    "    \n",
    "    x = torch.randn(batch_size, seq_len, dim)\n",
    "    mhsa = MHSA(D=dim, head_dim=head_dim)\n",
    "    output = mhsa(x)\n",
    "    \n",
    "    # basic shape tests\n",
    "    assert output.shape == (batch_size, seq_len, dim), f\"Expected shape {(batch_size, seq_len, dim)} but got {output.shape}\"\n",
    "    \n",
    "    # test attention weights sum to 1\n",
    "    q, k, v = mhsa.wq(x), mhsa.wk(x), mhsa.wv(x)\n",
    "    q = q.view(batch_size, mhsa.n_heads, seq_len, head_dim)\n",
    "    k = k.view(batch_size, mhsa.n_heads, seq_len, head_dim)\n",
    "    scores = torch.einsum('bnqd,bnkd->bnqk', q, k) / (head_dim ** 0.5)\n",
    "    attn_weights = F.softmax(scores, dim=-1)\n",
    "    \n",
    "    # check if attention weights sum to 1 (with small numerical tolerance)\n",
    "    assert torch.allclose(attn_weights.sum(dim=-1), torch.ones_like(attn_weights.sum(dim=-1)), atol=1e-6)\n",
    "    \n",
    "    print(\"All tests passed!\")\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "envi",
   "language": "python",
   "name": "lingua_env"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
